import pandas as pd
from tqdm import tqdm
import torch
from transformers import pipeline, AutoTokenizer
import numpy as np
import gc
import warnings
from transformers import logging
# Quiet the console: transformers logging is limited to errors, and every
# UserWarning / FutureWarning raised anywhere in the process is ignored.
logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Access token later passed to the model pipeline. The file is closed
# deterministically, and surrounding whitespace (typically the trailing newline
# at end of file) is stripped so it does not become part of the token.
with open('./llama_key.txt', 'r', encoding='utf-8') as key_file:
    llama_key = key_file.read().strip()

device = 'cuda'
# Data import and clean: drop rows whose premise is missing, then renumber the
# index 0..n-1 (later code relies on index membership to separate examples).
motn = pd.read_csv('./data/freedom_test.csv', encoding="ISO-8859-1")
motn = motn[~motn['premise'].isna()]
motn.reset_index(drop=True, inplace=True)

# Instantiate model: a Hugging Face text-generation pipeline for Llama 3.1 8B
# Instruct, with weights loaded as bfloat16, `device` passed as device_map and
# the access token read earlier supplied for the download.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
pipeline_kwargs = dict(
    model=model_id,
    model_kwargs=dict(torch_dtype=torch.bfloat16),
    device_map=device,
    token=llama_key,
)
pipe = pipeline("text-generation", **pipeline_kwargs)

#############
## Label Docs
#############
# Few-shot sizes to run; each one yields a labelled copy of the held-out rows.
shots = [25]
shot_dfs = {}

# The first label value (in order of appearance in the data) receives
# floor(shot / 2) examples; every other label value receives ceil(shot / 2).
first_label = motn['entailment'].unique()[0]

for shot in shots:
    # get the examples: a reproducible (random_state=1) sample per label value
    n_first = int(np.floor(shot / 2))
    n_other = int(np.ceil(shot / 2))
    shots_df = motn.groupby('entailment', group_keys=False).apply(
        lambda x: x.sample(n=n_first if x.name == first_label else n_other, random_state=1)
    )
    # redefine df to not have the examples (group_keys=False keeps motn's index)
    test_df = motn.loc[~motn.index.isin(shots_df.index)].copy()
    # format examples as "<premise>\nLabel: <entailment>" pairs
    examples = "\n".join(shots_df['premise'].astype(str) + '\nLabel: ' + shots_df['entailment'].astype(str))
    
    # user message; "{{}}" renders as a literal "{}" slot for the answer
    user_message = f"""You are a classifier that determines if a survey answer is about freedom and rights. The survey answer I will show is a response to the question "What does democracy mean to you?" 
    Answers about freedom and rights reference "freedom" or "liberty" in the abstract, or reference a particular freedom such as the freedom to express one's opinion, freedom of religion, or other rights mentioned in the Bill of Rights. They may also contain references to freedom from government, freedom of speech, the Second Amendment, references to civil rights or substantive rights, and references to capitalism or free enterprise. 
    However, the answer is not about freedom and rights if the answer is about voting rights, the "right" to vote, or anything having to do with voting. 
    
    Here are some examples:
    {examples}
    
    Now I will show you an answer to the question. Read the answer and then determine if the answer is about freedom and rights or not. Here is the respondent's answer:
    "{{}}"
    If the answer is about freedom and rights, return 0. If it is not about freedom and rights, return 1. Do not explain your answer, and only return the number.
    """
    # Split around the answer slot instead of calling user_message.format(doc):
    # str.format would also parse any "{" / "}" inside the interpolated example
    # answers and raise or corrupt the prompt. The slot is the LAST "{}" because
    # the examples all precede it.
    prompt_head, prompt_tail = user_message.rsplit('{}', 1)
    
    res = []
    for doc in tqdm(test_df['premise'], desc=f"{shot}-shot"):
        messages = [{"role": "user", "content": f"{prompt_head}{doc}{prompt_tail}"}]
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The rendered chat template already contains the BOS token, so keep the
        # pipeline's tokenizer from prepending a second one.
        outputs = pipe(
            prompt,
            max_new_tokens=2,
            do_sample=False,
            return_full_text=False,
            pad_token_id=pipe.tokenizer.eos_token_id,
            add_special_tokens=False,
        )
        res.extend(out['generated_text'] for out in outputs)
    
    # get results: the first non-whitespace character of each generation
    # ('' when the model produced nothing printable)
    answers = [text.strip()[:1] for text in res]
    n_unparsed = sum(answer not in ('0', '1') for answer in answers)
    if n_unparsed:
        print(f"warning: {n_unparsed} of {len(answers)} generations were neither '0' nor '1'; labelled 0")
    
    # put into df: 1 = not about freedom and rights, 0 = everything else
    test_df[f'llama_{shot}shot'] = [1 if answer == '1' else 0 for answer in answers]
    shot_dfs[shot] = test_df
    
    # print progress
    print(f"finished {shot} shots")
    gc.collect()

# Export: one CSV per few-shot size, e.g. ./data/llama_motn_25shot.csv.
# Iterating shot_dfs (rather than indexing a fixed key) keeps the export
# working when `shots` is changed and writes every configured size.
for shot, labelled_df in shot_dfs.items():
    labelled_df.to_csv(f'./data/llama_motn_{shot}shot.csv', index=False)